
Conversation

@manishamde
Contributor

Given the popular demand for gradient boosting and AdaBoost in MLlib, I am creating a WIP branch for early feedback on gradient boosting, with AdaBoost to follow soon after this PR is accepted. This is based on work done along with @hirakendu that had been pending behind the decision tree optimizations and random forests work.

Ideally, boosting algorithms should work with any base learners. This will soon be possible once the MLlib API is finalized -- we want to ensure we use a consistent interface for the underlying base learners. In the meantime, this PR uses decision trees as base learners for the gradient boosting algorithm. The current PR allows "pluggable" loss functions and provides least squares error and least absolute error by default.
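To make "pluggable" concrete, here is a minimal sketch of what such a loss interface with the two default losses could look like; the method names and signatures are illustrative, not necessarily the API in this PR.

```scala
// Illustrative sketch only: a pluggable loss interface with the two default losses.
trait Loss extends Serializable {
  /** Gradient of the loss with respect to the current prediction. */
  def gradient(prediction: Double, label: Double): Double
  /** Loss value for a single (prediction, label) pair. */
  def computeError(prediction: Double, label: Double): Double
}

// Least squares error: L(y, F) = (y - F)^2 / 2, gradient = -(y - F).
object SquaredError extends Loss {
  override def gradient(prediction: Double, label: Double): Double = -(label - prediction)
  override def computeError(prediction: Double, label: Double): Double = {
    val diff = label - prediction
    diff * diff / 2.0
  }
}

// Least absolute error: L(y, F) = |y - F|, gradient = -sign(y - F).
object AbsoluteError extends Loss {
  override def gradient(prediction: Double, label: Double): Double =
    if (label - prediction < 0) 1.0 else -1.0
  override def computeError(prediction: Double, label: Double): Double =
    math.abs(label - prediction)
}
```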

Here is the task list:

  • Gradient boosting support
  • Pluggable loss functions
  • Stochastic gradient boosting support – Re-use the BaggedPoint approach used for RandomForest.
  • Binary classification support
  • Support configurable checkpointing – This approach will avoid long lineage chains.
  • Create classification and regression APIs
  • Weighted Ensemble Model -- created a WeightedEnsembleModel class that can be used by ensemble algorithms such as random forests and boosting.
  • Unit Tests

Future work:

  • Multi-class classification is currently not supported by this PR since it requires discussion on the best way to support "deviance" as a loss function.
  • BaggedRDD caching -- Avoid repeating the feature-to-bin mapping for each tree estimator once the standard API work is completed.

cc: @jkbradley @hirakendu @mengxr @etrain @atalwalkar @chouqin

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/21079/

@jkbradley
Member

@manishamde Thanks for the WIP PR!

About classification, what points need to be discussed? Why is it more difficult to figure out than regression? (Also, I personally am not a big fan of the name "deviance" even though it is used in sklearn and in Friedman's paper. I prefer more descriptive names like LogLoss.)

Also, will this be generalized to support weighted weak hypotheses, common in most boosting algorithms?

For the final Model produced, should we use the same class for both random forests and gradient boosting? It could be a TreeEnsemble model (to be generalized later to a WeightedEnsemble model).

@jkbradley
Member

@epahomov If you or your student are able to take a look at this, I'm sure @manishamde would appreciate it. This PR will hopefully be generalized to include Classification. It's nice in that it has infrastructure for multiple losses. Thank you!

@manishamde
Contributor Author

@jkbradley

I meant multi-class classification. As you pointed out, binary classification should be similar to the regression case, but I am not sure one can handle multi-class classification with one tree. We might have to resort to a one-vs-all strategy. I also agree with you on the naming convention -- log loss or negative binomial log-likelihood are better names.
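For reference, the quantity being named here -- Friedman's negative binomial log-likelihood, a.k.a. log loss -- for labels $y \in \{-1, +1\}$ and margin $F(x)$ is

$$L(y, F) = \log\bigl(1 + e^{-2yF}\bigr).$$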

Yes, I plan to handle weighted weak hypotheses. In fact, I needed it for something like AdaBoost and had to remove it before submitting this PR. Do you think it makes sense to do it along with this PR or do it in the subsequent AdaBoost PR?

I agree about the WeightedEnsemble model. Let me add it to the TODO list.

@jkbradley
Member

@manishamde
Multi-class classification: Good point; I agree. I think this implementation can support binary, but we can do another to support multiclass. (For multiclass, I think the AdaBoost.OC-style error-correcting versions might be the best options.)

Weighted weak hypotheses: I am OK if this initial PR does not include weights, but then weights should be prioritized for the next update.

For the WeightedEnsemble, that generalization could be part of this PR or a follow-up.

Once this is ready, I'll be happy to help with testing (e.g., tuning checkpointing intervals and checking general performance).

Thanks!

@manishamde
Contributor Author

@jkbradley error-correcting codes will be a good option to support, though we should also have a generic one-vs-all classifier. Yes, weight support will definitely be a part of the AdaBoost PR. Let's discuss the WeightedEnsemble as we get close to completing the PR.

Thanks for helping with the testing. I am currently implementing the TreeRdd caching and subsampling without replacement. After that, we can start testing in parallel along with further code development.

@jkbradley
Member

Sounds good!

@SparkQA

SparkQA commented Oct 6, 2014

QA tests have started for PR 2607 at commit 3973dd1.

  • This patch merges cleanly.

@SparkQA

SparkQA commented Oct 6, 2014

QA tests have finished for PR 2607 at commit 3973dd1.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class GradientBoosting (
    • case class BoostingStrategy(
    • trait Loss extends Serializable
    • class GradientBoostingModel(trees: Array[DecisionTreeModel], algo: Algo) extends Serializable

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/21351/

@SparkQA

SparkQA commented Oct 6, 2014

QA tests have started for PR 2607 at commit 78ed452.

  • This patch merges cleanly.

@SparkQA

SparkQA commented Oct 7, 2014

QA tests have finished for PR 2607 at commit 78ed452.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class GradientBoosting (
    • case class BoostingStrategy(
    • trait Loss extends Serializable
    • class GradientBoostingModel(trees: Array[DecisionTreeModel], algo: Algo) extends Serializable

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/21353/

@SparkQA

SparkQA commented Oct 7, 2014

QA tests have started for PR 2607 at commit 4784091.

  • This patch merges cleanly.

@manishamde
Contributor Author

I have added stochastic gradient boosting by adding code for subsampling without replacement.
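For readers following along, a rough sketch of the idea, assuming a subsamplingRate parameter (the names are placeholders, not the PR's actual code):

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Per-iteration subsampling without replacement; a different seed per iteration
// gives a fresh subsample for each weak learner.
def sampleForIteration(
    input: RDD[LabeledPoint],
    subsamplingRate: Double,
    iteration: Int): RDD[LabeledPoint] = {
  if (subsamplingRate < 1.0) {
    input.sample(withReplacement = false, fraction = subsamplingRate, seed = iteration)
  } else {
    input
  }
}
```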

@manishamde
Contributor Author

Here is an interesting design discussion:

For trees and RFs, we convert input: RDD[LabeledPoint] to treeInput: RDD[TreePoint] and persist it in memory, since the same RDD is re-used during tree/forest building. However, for boosting, we need to construct a new tree at every iteration from a modified RDD whose labels (and weights) are modified w.r.t. the original dataset. This leads to a repeated conversion from LabeledPoint to TreePoint at every boosting iteration.

Here are a few approaches we can take:
(1) Cache input: RDD[LabeledPoint] and then convert to RDD[TreePoint] during each iteration. We also need to decide whether the RDD[TreePoint] needs to be cached by default.
(2) Convert input: RDD[LabeledPoint] to treeInput: RDD[TreePoint] just once at the start. However, we will need another method to predict using TreePoint instances instead of the standard LabeledPoint. With this strategy, we cache the RDD[TreePoint], re-use it every iteration, avoid the repeated binning cost, and avoid storing multiple RDDs in memory (possibly saving memory overall).

I have implemented (1), but I think (2) will be worthwhile to try. Any suggestions?
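To make the trade-off easier to see, here is a minimal sketch of the two options; BinnedPoint, bin, and relabel are stand-ins for the internal TreePoint machinery, not actual MLlib code.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Simplified stand-in for the internal binned representation.
case class BinnedPoint(label: Double, binnedFeatures: Array[Int])

// Option (1): keep RDD[LabeledPoint] cached; re-bin at every iteration.
def optionOne(input: RDD[LabeledPoint], bin: LabeledPoint => BinnedPoint,
              relabel: LabeledPoint => Double, numIterations: Int): Unit = {
  input.persist(StorageLevel.MEMORY_AND_DISK)
  (0 until numIterations).foreach { _ =>
    val treeInput = input.map(lp => bin(LabeledPoint(relabel(lp), lp.features)))
    // ... fit the next tree on treeInput (binning cost paid every iteration) ...
  }
}

// Option (2): bin once, cache RDD[BinnedPoint]; only the labels change per iteration,
// which requires being able to predict directly on the binned representation.
def optionTwo(input: RDD[LabeledPoint], bin: LabeledPoint => BinnedPoint,
              relabel: BinnedPoint => Double, numIterations: Int): Unit = {
  val binned = input.map(bin).persist(StorageLevel.MEMORY_AND_DISK)
  (0 until numIterations).foreach { _ =>
    val treeInput = binned.map(bp => bp.copy(label = relabel(bp)))
    // ... fit the next tree on treeInput (binning cost paid only once) ...
  }
}
```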

@SparkQA

SparkQA commented Oct 7, 2014

QA tests have finished for PR 2607 at commit 4784091.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class GradientBoosting (
    • case class BoostingStrategy(
    • trait Loss extends Serializable
    • class GradientBoostingModel(trees: Array[DecisionTreeModel], algo: Algo) extends Serializable

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/21354/

@jkbradley
Member

@manishamde About the 2 caching options, I agree with your decision to do (1) first. It would be nice to try (2) later on (another PR?), but I don't think it is too high-priority. Perhaps we can eventually have learning algs provide convertDatasetToInternalFormat() and predictUsingInternalFormat() methods (with less verbose names), once the standard API is in place.
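Purely as an illustration of the shape of that idea (the trait and method names below are hypothetical, borrowed from the comment above; nothing like this exists in MLlib yet):

```scala
import org.apache.spark.rdd.RDD

trait InternalFormatSupport[Datum, Internal, Model] {
  /** One-time conversion (e.g., binning) into the learner's internal representation. */
  def convertDatasetToInternalFormat(data: RDD[Datum]): RDD[Internal]
  /** Predict directly on the internal representation, avoiding a round-trip conversion. */
  def predictUsingInternalFormat(model: Model, datum: Internal): Double
}
```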

@manishamde
Contributor Author

@jkbradley Cool. I am sure we will see a definite performance gain once we implement support for (2) after we have a standard API.

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/22535/

@SparkQA

SparkQA commented Oct 30, 2014

Test build #22536 has finished for PR 2607 at commit b4c1318.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/22536/

@jkbradley
Member

Thanks for the updates; I'll take a look. I think that it will be very important to include checkpointing, but I am OK with adding it later on. (Since boosting is sequential, I could imagine it running for much longer than bagging/forest algorithms, so protecting against driver failure will be important.)

@manishamde
Contributor Author

@jkbradley I agree with protection against driver failure for long sequential operations. However, in this case we will just be checkpointing partial models rather than the intermediate datasets, similar to other iterative algorithms such as LR. I look forward to your feedback on the new logic.

@jkbradley
Member

True, perhaps we'll need to checkpoint not just the labels but also the data itself for Spark to know how to resume training. Postponing checkpointing seems like a good idea for now.
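For when this gets picked up again, one possible shape of interval-based checkpointing of the working data to keep lineage short; checkpointInterval, relabel, and the directory are placeholders, not a committed design.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

def boostWithCheckpointing(
    sc: SparkContext,
    input: RDD[LabeledPoint],
    relabel: LabeledPoint => Double,
    numIterations: Int,
    checkpointInterval: Int = 10): Unit = {
  sc.setCheckpointDir("/tmp/gbt-checkpoints")
  var workingData = input
  for (m <- 1 to numIterations) {
    workingData = workingData.map(lp => LabeledPoint(relabel(lp), lp.features))
    if (m % checkpointInterval == 0) {
      workingData.persist(StorageLevel.MEMORY_AND_DISK)
      workingData.checkpoint()  // truncates the lineage built up by the repeated maps
    }
    // ... fit the next tree on workingData and fold it into the partial model ...
  }
}
```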

Member

I don't think toString() should print the full model. toString should be concise, and toDebugString should print the full model.
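A sketch of the split being suggested, assuming the model wraps an Array[DecisionTreeModel] named trees (class and field names are illustrative):

```scala
import org.apache.spark.mllib.tree.model.DecisionTreeModel

class GradientBoostingModelSketch(val trees: Array[DecisionTreeModel]) {
  // Concise summary, suitable for logs and REPL output.
  override def toString: String =
    s"GradientBoostingModel with ${trees.length} trees"

  // Full model dump, one tree at a time.
  def toDebugString: String =
    trees.zipWithIndex.map { case (tree, i) =>
      s"Tree $i:\n" + tree.toDebugString
    }.mkString("\n")
}
```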

Contributor Author

Will do.

@jkbradley
Member

@manishamde The logic looks better (especially since you caught the learningRate bug!). After the API update (train*, BoostingStrategy, and making AbsoluteError and LogLoss private), I think this will be ready.

@manishamde
Contributor Author

@jkbradley Thanks for the confirmation! I will now proceed to finish the rest of the tasks -- should be straightforward.

Member

This should be learningRate too, right?

Contributor Author

I think the learning rate is applied after the first model.

Member

In the Friedman paper, the first "model" is just the average label (for squared error). I think it's fine to keep it as is; that way, running for just 1 iteration will behave reasonably.
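In other words, with the current code the ensemble prediction after $M$ iterations is

$$F_M(x) = f_0(x) + \sum_{m=1}^{M} \eta \, h_m(x),$$

where $f_0$ is the initial model (in Friedman's paper, just the mean label for squared error), $h_m$ are the subsequent trees, and the learning rate $\eta$ scales only the latter.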

Contributor Author

Yup.

@SparkQA

SparkQA commented Oct 31, 2014

Test build #22582 has started for PR 2607 at commit ff2a796.

  • This patch merges cleanly.

@SparkQA

SparkQA commented Oct 31, 2014

Test build #22582 has finished for PR 2607 at commit ff2a796.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/22582/

@SparkQA

SparkQA commented Oct 31, 2014

Test build #22596 has started for PR 2607 at commit 991c7b5.

  • This patch merges cleanly.

@manishamde
Contributor Author

@jkbradley I cleaned up the public API based on our discussion. Going with a nested structure where we have to specify the weak learner parameters separately is cleaner, but it puts the onus on us to write very good documentation.

I am tempted to keep AbsoluteError and LogLoss as is with the appropriate caveats in the documentation. A regression tree with mean predictions at the terminal nodes is not the best approximation (as pointed out by the TreeBoost paper), but it's not a bad one either. After all, we are just making approximations of the gradient at each step. Moreover, other weak learning algorithms (for example LR) will be hard to tailor towards each specific loss function. Thoughts?
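For context, the approximation being discussed: at iteration $m$ the weak learner is fit to the pseudo-residuals of the previous model,

$$r_{im} = -\left[\frac{\partial L(y_i, F(x_i))}{\partial F(x_i)}\right]_{F = F_{m-1}},$$

so any regression learner that fits these residuals reasonably well (a tree with mean leaf predictions, LR, etc.) yields a usable, if not loss-optimal, gradient step.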

@jkbradley
Member

True, it's a good point about LR. OK, let's keep them with caveats, but hopefully run some tests to make sure they seem to be working. I'll make a pass tomorrow morning; thanks for the updates!

@SparkQA

SparkQA commented Oct 31, 2014

Test build #22596 has finished for PR 2607 at commit 991c7b5.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/22596/

@jkbradley
Member

@manishamde LGTM! Thanks for updating the Strategy. I think this is ready to be merged, though I still plan to update the train* methods to eliminate the ones taking lots of parameters. In particular, I plan to:

  • Make builder methods for Strategy and BoostingStrategy to make them easy to construct from Java (see the sketch below).
  • Eliminate the train* methods taking lots of parameters.
  • Write examples in Scala and Java to make sure everything is easy to do from Java.

Does that sound OK?
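For the first bullet, a purely illustrative sketch of how Java-friendly construction might look; the method names here are hypothetical placeholders, not the merged API.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Hypothetical usage of a defaults-plus-setters BoostingStrategy.
def example(trainingData: RDD[LabeledPoint]): Unit = {
  val boostingStrategy = BoostingStrategy.defaultParams("Regression")
  boostingStrategy.setNumIterations(100)
  boostingStrategy.setLearningRate(0.1)
  val model = GradientBoosting.train(trainingData, boostingStrategy)
}
```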

Thanks very much for contributing GBT! It's a big step forward for MLlib.

CC: @mengxr

@manishamde
Contributor Author

Thanks. Sounds good to me.

I tried to use the builder pattern to help with the Java use case, but I guess we can handle it separately.

@manishamde
Contributor Author

@mengxr Could we get this merged? :-)

@mengxr
Contributor

mengxr commented Nov 1, 2014

I've merged this into master. Thanks @manishamde for contributing and @codedeft and @jkbradley for review!

@asfgit closed this in 8602195 on Nov 1, 2014
